Conversation

@cyanguwa cyanguwa commented Sep 22, 2025

Description

This PR adds max_logit support to the FusedAttention and UnfusedDotProductAttention backends in TE-PyTorch. max_logit is the per-head maximum of mask(Q x K^T x scale + bias), and it is used by the MuonClip optimizer to rescale the Q and K projection weights.
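
To make the quantity concrete, here is a minimal PyTorch sketch (not TE's implementation; the function name, the BSHD layout, and the boolean mask convention are assumptions for illustration) of how a per-head max_logit can be computed alongside unfused attention:

```python
import torch

def attention_with_max_logit(q, k, v, bias=None, mask=None, scale=None):
    """Toy unfused attention that also returns the per-head max_logit.

    q, k, v: [batch, seq, heads, dim] (BSHD); mask is boolean, True = masked out.
    """
    d = q.shape[-1]
    scale = d ** -0.5 if scale is None else scale
    # logits = mask(Q x K^T x scale + bias), shape [batch, heads, s_q, s_kv]
    logits = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if bias is not None:
        logits = logits + bias
    if mask is not None:
        logits = logits.masked_fill(mask, float("-inf"))
    # Per-head maximum of the pre-softmax logits; MuonClip uses this to decide
    # how much to rescale the Q/K projection weights.
    max_logit = logits.amax(dim=(0, 2, 3))  # shape: [heads]
    out = torch.einsum("bhqk,bkhd->bqhd", torch.softmax(logits, dim=-1), v)
    return out, max_logit
```

In this sketch max_logit has shape [num_heads], matching the per-head maximum described above.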

This PR supports the FP16 and BF16 precisions and the BSHD and SBHD formats. It covers both non-CP and CP (cp_comm_type = {"p2p", "a2a", "a2a+p2p", "all_gather"}) cases.
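
Under context parallelism each rank only sees its own sequence chunk, so the local per-head maxima have to be combined across the CP group. Below is a rough sketch of that final reduction with torch.distributed; the function name and the cp_group handle are illustrative, not TE's internal code:

```python
import torch
import torch.distributed as dist

def reduce_max_logit_across_cp(local_max_logit: torch.Tensor, cp_group) -> torch.Tensor:
    """Combine per-rank, per-head maxima into the global per-head max_logit.

    local_max_logit: [heads], computed over this rank's sequence chunk.
    A MAX all-reduce over the context-parallel group yields the global value,
    whichever way the chunks were exchanged (p2p, a2a, a2a+p2p, all_gather).
    """
    dist.all_reduce(local_max_logit, op=dist.ReduceOp.MAX, group=cp_group)
    return local_max_logit
```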

It contains a breaking change: a return_max_logit argument is added to nvte_get_fused_attn_backend, nvte_fused_attn_fwd and nvte_fused_attn_bwd. In the future, TE will pack the tensor and non-tensor arguments of these APIs into structs in order to avoid breaking changes like this.

Support for the THD format is also implemented in this PR and will be enabled once cuDNN supports it.

This PR requires cuDNN 9.13.1 and cudnn-frontend 1.15.
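
Since the fused path depends on a recent cuDNN, a quick runtime check from PyTorch can help before relying on the feature. The numeric encoding below (major*10000 + minor*100 + patch for cuDNN 9.x) is an assumption about how the version integer is reported:

```python
import torch

# torch.backends.cudnn.version() returns an integer, e.g. 91301 for cuDNN 9.13.1
# under the assumed major*10000 + minor*100 + patch encoding.
version = torch.backends.cudnn.version()
if version is None or version < 91301:
    print(f"cuDNN {version} detected; max_logit with FusedAttention needs cuDNN >= 9.13.1")
```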

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add max_logit support in FusedAttention and UnfusedDotProductAttention

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@cyanguwa cyanguwa added the 2.9.0 label Sep 22, 2025
cyanguwa and others added 17 commits September 30, 2025 11:22
@cyanguwa cyanguwa requested a review from ptrendx October 15, 2025 13:52
@cyanguwa (Collaborator Author)

/te-ci L1

@cyanguwa (Collaborator Author)

/te-ci L1

@BoxiangW (Contributor)

LGTM thx

@cyanguwa cyanguwa requested a review from skyw October 22, 2025 08:36
@cyanguwa (Collaborator Author)

/te-ci L1

@skyw skyw left a comment

@BoxiangW has reviewed and LGTM.

Approving.

@BoxiangW (Contributor) commented Oct 22, 2025

One more thing on this PR: I think we agreed earlier on changing the name to max_logit or max_qk_logit, since it represents the value before the softmax op.

@cyanguwa cyanguwa changed the title from "[PyTorch] Add max_score support for MuonClip" to "[PyTorch] Add max_logit support for MuonClip" Oct 23, 2025
@cyanguwa cyanguwa requested a review from mk-61 October 23, 2025 15:17
@cyanguwa (Collaborator Author)

/te-ci L1

@vcherepanov-nv (Collaborator)

LGTM

@cyanguwa (Collaborator Author)

/te-ci L1

@cyanguwa cyanguwa merged commit 87cb26c into NVIDIA:main Oct 25, 2025
47 of 53 checks passed
KshitijLakhani pushed a commit that referenced this pull request Oct 28, 2025
* add max_score for fused/unfused F16 non-CP
* calculate max per head instead of max over all heads
* fix fused attn max_score shape
* revert FE to github
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* update FE to 1.15.0-rc
* fix merge
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* reduce ew kernels; fix causal masks; add more tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor fix to tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* remove logic for flash-attn
* WIP: add CP support for p2p/a2a/all_gather
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* minor improvements of implementation/tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* WIP: add thd support
* add thd to UnfusedDPA
* fix lint
* more fixes for lint
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* update to FE 1.15
* remove unneeded changes
* disable unfused for thd + pad_between_seqs
* minor fixes
* disable thd for unfused until bug is fixed
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix all_gather
* fix all gather
* rename max_score to max_logit
* fix all_gather
* fix all_gather
* disable fused attn + thd

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>